{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2iYLQO5Evvqy"
      },
      "source": [
        "# Training a 'Robust' Adapter with AdapterDrop\n",
        "\n",
        "This notebook extends our quickstart adapter training notebook to illustrate how we can use AdapterDrop\n",
        "to robustly train an adapter, i.e. an adapter that allows us to dynamically drop layers for faster multi-task inference.\n",
        "Please have a look at the original adapter training notebook for more details on the setup."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-XTIOLv0isn"
      },
      "source": [
        "## Installation\n",
        "\n",
        "First, let's install the required libraries:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ju-alwbHmKYA",
        "outputId": "e27a4b67-23a8-4f64-d79e-0438c20aeb90",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:22:15.568306600Z",
          "start_time": "2023-08-17T10:22:10.046264200Z"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.3/204.3 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m114.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m30.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m107.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m53.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.2/251.2 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Use the %pip magic (instead of !pip) so the packages are installed into the\n",
        "# environment of the running kernel\n",
        "%pip install -Uq adapters\n",
        "%pip install -q datasets\n",
        "%pip install -Uq accelerate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 433,
          "referenced_widgets": [
            "af31afecb7c74452a6742b1d27c72e6a",
            "446f1223019548148082914b79118b20",
            "2b22b38b0fc4452fb7357ef42fbe7d1f",
            "eb96086559bd42ba92757c5c610d8df8",
            "f091264355bd49bd878bb504a148b5aa",
            "fa07ebd07a1b4efb902120e862bd9e51",
            "37b99bd25208450f887a8f4895c26db8",
            "521881c77f4045939dab280c6ade1a6d",
            "224f50e75056483cadef731e5ac67c5f",
            "f6f86104b6174f7f88b08b07b3c0d24e",
            "4acc59da308b4ce3a1c88ac9c5816063",
            "f3f34b5b27804d4285e25868b50c1a3d",
            "f53531d2459b47488905a6177cf35afe",
            "f68cf0bbd5554b8d95e3ce62aa94b5f8",
            "935e0d5dd86c44dd9fe050d8af871f1e",
            "ddb7423eda4242b2944415709bac3b0e",
            "ae2d0e38b01747f08cbe17193e054c07",
            "a61f5114dc3540978e56be4edff7c74e",
            "b11852d1ec5e48b4927cf7e373f7cf3b",
            "fe951e22c380445f91cd23f6d584051a",
            "7faa414272e64e5a877f9d4da104efb4",
            "aa6d91e16f69487c8a570f0bc383f5f7",
            "5781d448c6dd4eac803e55698696dc92",
            "710384b099c948d4aa6fe7d57b88a677",
            "3a3401331523485d8151355cf444d6a5",
            "7ada3e3a52af4bce904a277112eef57a",
            "838b8a4b3d194ada9018a1d6cacf1aed",
            "b7c77b6e6383451f93cbc01784c9dcd6",
            "afe656aa689c4e01ad42d96adac7a00d",
            "6743b01effec4703bdd267c68904ed0b",
            "528bce3bacdc49c4902db26bcd47f053",
            "a49e848bd28d4b8ba03dd7e1372ba20e",
            "4df10860cebb4e9b93cba9050edf7ea7",
            "01bec9a46d4a4ba7a616f9a373ab5841",
            "2b70acbf1a3d4049a60f9772cfb04fda",
            "bca32b6d93ca4a80a997fce497228935",
            "ca005388a42648cbaf9a299285417ce5",
            "cefe0372d1bb4a86939434660bc0e4c9",
            "d8745bfddd07481d9603f9051c20bddf",
            "f15283b532cc44838d6f137a27c67ca9",
            "6f2a3ba4b1a041d3ab8517d806df2aba",
            "4a333dcb24b14ff786dbc5a8a18a6f66",
            "d6485841d1b644a0a8fad8bc22a03ad3",
            "804dcfd16bf34377923a00e6e089cb45",
            "d706983dad4c44b58a68c9c2e08155e8",
            "07a2ae8eb8a04dcfaa716f7fc6e3fe5a",
            "63796b740f6a4535bfcc7597b28ddcbf",
            "caea229a178c4f2e82fe640944565c5d",
            "4d712e71fdfe4543bd3c39e63826bd04",
            "61a8cab1c2834a17bf7f2323532a57c9",
            "0b451796189b4b24b0925fb073fa1ab1",
            "6999db080b9a42febbafcbe0eb5d5686",
            "f5d8e7f0393542098ef16cdabeae82a0",
            "d6a92b96815b40ee9644f92a55a4a5cf",
            "7c413729baa44fa98f664066ab2572f3",
            "647fc145ce124ae593a38b64917fe7e2",
            "88b58a7a6ba044f58171588c25ed5813",
            "8acf1ae378fe407dac9331a5075089b5",
            "9a3e144b2f994d9ca8ea2568c2dacc51",
            "555443e615cf4f7baa964501d6d6b7d1",
            "e0109d1a0f084ad7b35d7368e24325be",
            "1c29394bd2f44666a1fe51a90f79cce8",
            "5a9e3a42ff824b4c8830951d8bc32ef4",
            "adb65e2d275d4b6292b30a24469bfb9e",
            "24ad809919bb41948aee62b179145d1e",
            "4a4cb888d5244402a07031097a9dc9cf",
            "aa3380796dca468ea599f339d4d2f9ad",
            "cd05a336f8be43abaf98df1baafccded",
            "97716a04b3674b4a9c0cd7d6bdb922ca",
            "278dc0daa494489da9bf1951c62b79ab",
            "99e5e703835a4074adb0099973baa60b",
            "eae00629dfb342fd813a049c0e435902",
            "479d6068ec064a62aa1c6f0dd0cfe06a",
            "22f0a4967c6440ce9f8098216e9bcf77",
            "e31fa4e730bc4215b63e370e736a22b9",
            "5e5f2e6cb8b842a6afdade096f6ab7dc",
            "b1984f751a864efe8618e0bf964eb78e",
            "b054c540d6244a218c30698dfb7e38f9",
            "3e7e17b4e44c47c08b7fd170cfc968a5",
            "23f7713745d64f4da5d6bc2ec9519af2",
            "d21ae5c21b774573a67e8494f5147d2c",
            "11155f5f68734c0b9ae679c4153eae7d",
            "58a9b612d4a3412797cefac7ef6073e9",
            "965fae3b75a04508a0d4bec8a07d7169",
            "e643943782d14bdcbadd848d7c966722",
            "b2b51d07f5314af29a24c6737149e8f8",
            "5597e5a3a5b444cebe33a7b100456f37",
            "41506f06950d4ff2ba9ceb5c85d4a347",
            "04d83fe69dd0462a9ff4ada0eac80e40",
            "e43f0be0fa684d15be0aff1ac232506e",
            "d11a4801b1f342d7958c8b37a77a3cbf",
            "968a07cc377e4d2e845db9ee866cd024",
            "c5b406dba16b478ea440a21fa10eb888",
            "e6346d58d75348198f558d81f1764148",
            "cb525c89ce1647e2bfa1485383710b4f",
            "5b6d6bd13ca84a419a3ada9944a20be5",
            "4e209532b4cc452c86bf6c0596ce400f",
            "4f2ca0fa250048a8b020106388d777c8",
            "5bf1aa6c7ad9414aac8771debbbcf8d1",
            "b6c664e17842414d977add2db60e0538",
            "76f718e53ff8453e926a7ce31058d8e6",
            "1e88c0c9c6b44e5091eb205f2dd18c7f",
            "de45dccfa7624dbc906ea9d992018001",
            "2cd930069cc240bca730116780514a7b",
            "019567da24b848e4a4adb1ddf8bc8c77",
            "ea8c04981ab64336913a54a3cffb72f0",
            "ea840af8a18843ea85d67a0e33babe1d",
            "8492e3eb00b94f65bde5613934a048ce",
            "8149b1bb97ce46598700352ff0381ff6",
            "75d35b3cafc1467484aa082ca6aa7bf4",
            "213d0fffca2942268119ac6a343132f5",
            "6e6f97d419c44318ad4ef9352e6a3537",
            "0ffb7dea03dd41b8b8884f620a7c6c0f",
            "f6c66249c1704e7fa4544ac82120d16c",
            "dd957dd0ed45473b850ef005d6f763c8",
            "05e7abcc40fc4923bc20283de79b2cd8",
            "3cd028423c5149e187673311ec7e059c",
            "3a917b30da3746d5b05eaedce4ee68d4",
            "53d53ebb2699469bb55a29001b73dce4",
            "51b43bedd258407ba7dde4f852b5f82b",
            "09bfb00f1e82482cbd8e894c34e0ead9",
            "989c79a2b4ef4d0a86e36daa5bfda591",
            "15a1cdc48e844da9804d1b55955b1c11",
            "4b3299ed51ea4eeebbdb3e2381b8eb42",
            "6813beffe74f45358df6eaf9fc10a716",
            "79c1889e0d1e4e12a6ff9c84dfad51e4",
            "ca694747816f4359b84a2d0702d44e87",
            "5f38595c519c4ba7b727f3cd31cdc6ed",
            "5c2e4716d30c482b81ecc7c9ed937e9f",
            "95f6805159124b9b9664b8dc8bc96114",
            "db4e6cf63b764c199b163e1fb47bfeda",
            "28e09cc6d64f4dce90ebf5ac3993938c",
            "ffffac8ac58c4d9f8b8c045928310a2b",
            "21308b05960d4f97907fe4328db3f43c",
            "446ed0561a404d759928582e5b0b973b",
            "e4ffd22b0d1342d9abf234259401ea3f",
            "86473dc1d87741c78e5517ecaa642a94",
            "beff631b267749aa8a7ad564ee1f5ef9",
            "fb9edcff8b634e7e9222a0d505684c1c",
            "484e4b9638224bfdbc70e95935cb6ab2",
            "575da7b329f74a758e364ca049c888f1",
            "1e71033491304165a7fa71c84e4e721f",
            "4633e1569728411c82e118e766641fd4"
          ]
        },
        "id": "7Mx916lBCfoL",
        "outputId": "8b7ede9c-215a-4a75-e9c3-ccd92824a053",
        "pycharm": {
          "name": "#%%\n"
        },
        "ExecuteTime": {
          "end_time": "2023-08-17T10:22:24.609863400Z",
          "start_time": "2023-08-17T10:22:15.568306600Z"
        }
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading builder script:   0%|          | 0.00/5.03k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "af31afecb7c74452a6742b1d27c72e6a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading metadata:   0%|          | 0.00/2.02k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "f3f34b5b27804d4285e25868b50c1a3d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading readme:   0%|          | 0.00/7.25k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "5781d448c6dd4eac803e55698696dc92"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading data:   0%|          | 0.00/488k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "01bec9a46d4a4ba7a616f9a373ab5841"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split:   0%|          | 0/8530 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d706983dad4c44b58a68c9c2e08155e8"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating validation split:   0%|          | 0/1066 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "647fc145ce124ae593a38b64917fe7e2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating test split:   0%|          | 0/1066 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "aa3380796dca468ea599f339d4d2f9ad"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "b054c540d6244a218c30698dfb7e38f9"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "04d83fe69dd0462a9ff4ada0eac80e40"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)lve/main/config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "b6c664e17842414d977add2db60e0538"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/8530 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "213d0fffca2942268119ac6a343132f5"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/1066 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "989c79a2b4ef4d0a86e36daa5bfda591"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/1066 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ffffac8ac58c4d9f8b8c045928310a2b"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "from transformers import RobertaTokenizer\n",
        "\n",
        "dataset = load_dataset(\"rotten_tomatoes\")\n",
        "tokenizer = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
        "\n",
        "def encode_batch(batch):\n",
        "  \"\"\"Tokenize the \"text\" field of a batch, padding/truncating every example to 80 tokens.\"\"\"\n",
        "  texts = batch[\"text\"]\n",
        "  return tokenizer(texts, padding=\"max_length\", truncation=True, max_length=80)\n",
        "\n",
        "# Tokenize the input data, then rename the target class column to \"labels\",\n",
        "# the name the transformers model expects\n",
        "dataset = dataset.map(encode_batch, batched=True).rename_column(\"label\", \"labels\")\n",
        "# Output only the required columns, as PyTorch tensors\n",
        "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S2-2CbfPGYvi"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84,
          "referenced_widgets": [
            "f7336b59487a47eaaa878f0089339070",
            "940c9b83e7a9451a97c818602117335d",
            "f324d6a61dc44bec84890fa940fd9d99",
            "df641f79fce74d04a607fac20045d8b5",
            "53e6cef9a13940c9977dbe76c4bec5e8",
            "211f84c7f6244dd2aaa943bc4a1eaf57",
            "26bea3bcc18f481784420455e145d9f5",
            "c582f5e22b17436588dca01e5d82ada9",
            "2a659681ce8a4d5290a94b24e4864bab",
            "160a6c88ead840b0953bcce0eda8b38b",
            "7b66541cd1354b25b75bf99a2947355d"
          ]
        },
        "id": "Tp9uG-pT-qgv",
        "outputId": "7ecbca2e-8ee4-4181-8c2a-20c6468d51d7",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:22:28.765981800Z",
          "start_time": "2023-08-17T10:22:24.609863400Z"
        }
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "f7336b59487a47eaaa878f0089339070"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of RobertaAdapterModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "source": [
        "from adapters import RobertaAdapterModel\n",
        "from transformers import RobertaConfig\n",
        "\n",
        "# Two-class configuration whose label ids are rendered as emoji\n",
        "config = RobertaConfig.from_pretrained(\"roberta-base\", num_labels=2, id2label={0: \"👎\", 1: \"👍\"})\n",
        "model = RobertaAdapterModel.from_pretrained(\"roberta-base\", config=config)\n",
        "\n",
        "# Register a new adapter and a classification head under the same name\n",
        "model.add_adapter(\"rotten_tomatoes\")\n",
        "model.add_classification_head(\"rotten_tomatoes\", num_labels=2)\n",
        "# Activate the adapter and set the model up for training it\n",
        "model.train_adapter(\"rotten_tomatoes\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ev5t_8i8HzJB"
      },
      "source": [
        "To dynamically drop adapter layers during training, we make use of HuggingFace's `TrainerCallback`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "5FRft_5AAlQd",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:22:29.060360100Z",
          "start_time": "2023-08-17T10:22:28.765981800Z"
        }
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from adapters import AdapterTrainer\n",
        "from transformers import TrainingArguments, EvalPrediction, TrainerCallback\n",
        "\n",
        "class AdapterDropTrainerCallback(TrainerCallback):\n",
        "  \"\"\"Trainer callback that randomly drops the adapter from the first layers during training.\n",
        "\n",
        "  At the beginning of every training step a random number of leading layers is chosen\n",
        "  and the adapter is skipped in those layers. At the end of the step the adapter is\n",
        "  re-enabled in all layers, so that any evaluation runs with the full adapter.\n",
        "  \"\"\"\n",
        "\n",
        "  def on_step_begin(self, args, state, control, **kwargs):\n",
        "    # The upper bound of randint is exclusive: between 0 and 10 leading layers are skipped\n",
        "    skip_layers = list(range(np.random.randint(0, 11)))\n",
        "    kwargs['model'].set_active_adapters(\"rotten_tomatoes\", skip_layers=skip_layers)\n",
        "\n",
        "  def on_step_end(self, args, state, control, **kwargs):\n",
        "    # Deactivate skipping layers as soon as the training step is done. The Trainer\n",
        "    # evaluates after this event, whereas `on_evaluate` is only called once an\n",
        "    # evaluation has already finished. Without this reset, evaluation would use the\n",
        "    # previous randomly chosen skip_layers and thus yield results not comparable\n",
        "    # across different epochs.\n",
        "    kwargs['model'].set_active_adapters(\"rotten_tomatoes\", skip_layers=None)\n",
        "\n",
        "  def on_evaluate(self, args, state, control, **kwargs):\n",
        "    # Called after an evaluation phase; kept as a safeguard so that no layers are\n",
        "    # skipped once an evaluation has been run.\n",
        "    kwargs['model'].set_active_adapters(\"rotten_tomatoes\", skip_layers=None)\n",
        "\n",
        "\n",
        "# Training hyperparameters; the output directory is ./training_output\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./training_output\",\n",
        "    overwrite_output_dir=True,\n",
        "    num_train_epochs=6,\n",
        "    learning_rate=1e-4,\n",
        "    per_device_train_batch_size=32,\n",
        "    per_device_eval_batch_size=32,\n",
        "    logging_steps=200,\n",
        "    remove_unused_columns=False,\n",
        ")\n",
        "\n",
        "def compute_accuracy(p: EvalPrediction):\n",
        "  \"\"\"Return the share of examples whose highest-scoring class matches the label.\"\"\"\n",
        "  predicted = np.argmax(p.predictions, axis=1)\n",
        "  is_correct = predicted == p.label_ids\n",
        "  return {\"acc\": is_correct.mean()}\n",
        "\n",
        "# Build the trainer on the train/validation splits\n",
        "trainer = AdapterTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    compute_metrics=compute_accuracy,\n",
        "    train_dataset=dataset[\"train\"],\n",
        "    eval_dataset=dataset[\"validation\"],\n",
        ")\n",
        "\n",
        "# Register the callback that drops adapter layers during training\n",
        "trainer.add_callback(AdapterDropTrainerCallback())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9iHhoYuLIdX3",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "We can now train and evaluate our robustly trained adapter!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 505
        },
        "id": "zZxaujENntNR",
        "outputId": "cc885733-f6f3-4396-e723-61aace9614c1",
        "pycharm": {
          "name": "#%%\n"
        },
        "ExecuteTime": {
          "end_time": "2023-08-17T10:26:14.257776400Z",
          "start_time": "2023-08-17T10:22:29.060360100Z"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1602' max='1602' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1602/1602 05:40, Epoch 6/6]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>200</td>\n",
              "      <td>0.556100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>400</td>\n",
              "      <td>0.362700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>600</td>\n",
              "      <td>0.328000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.315600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1000</td>\n",
              "      <td>0.295300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1200</td>\n",
              "      <td>0.287100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1400</td>\n",
              "      <td>0.302800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>0.277800</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='34' max='34' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [34/34 00:04]\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'eval_loss': 0.2743206322193146,\n",
              " 'eval_acc': 0.8893058161350844,\n",
              " 'eval_runtime': 4.3565,\n",
              " 'eval_samples_per_second': 244.694,\n",
              " 'eval_steps_per_second': 7.804,\n",
              " 'epoch': 6.0}"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# Run the training loop, then display the evaluation metrics as the cell output\n",
        "trainer.train()\n",
        "trainer.evaluate()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "toc_visible": true,
      "gpuType": "T4"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
